{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/W2D1-postcourse-bugfix/tutorials/W2D3_DecisionMaking/student/W2D3_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Neuromatch Academy: Week 2, Day 3, Tutorial 3\n",
    "# Linear Dynamical Systems & The Kalman Filter\n",
    "__Content creators:__ Caroline Haimerl and Byron Galbraith\n",
    "\n",
    "__Content reviewers:__ Jesse Livezey, Matt Krause, and Michael Waskom\n",
    "\n",
    "**Useful references:**\n",
    "- Roweis, Ghahramani (1998): A unifying review of linear Gaussian Models\n",
    "- Bishop (2006): Pattern Recognition and Machine Learning\n",
    "\n",
    "**Acknowledgement**\n",
    "\n",
    "This tutorial is in part based on code originally created by Caroline Haimerl for Dr. Cristina Savin's *Probabilistic Time Series* class at the Center for Data Science, New York University."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "outputId": "d75a9b05-d43e-4a8a-968e-6f5c0ecdddae"
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Introduction\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"6f_51L3i5aQ\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
    "In the previous tutorials we looked at inferring discrete latent states that give rise to our measurements. In this tutorial, we will learn how to infer a latent model when our states are continuous. Particular attention is paid to the Kalman filter and its mathematical foundation.\n",
    "\n",
    "In this tutorial, you will:\n",
    "* Review linear dynamical systems\n",
    "* Learn about and implement the Kalman filter\n",
    "* Explore how the Kalman filter can be used to smooth data from an eye-tracking experiment\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Install PyKalman (https://pykalman.github.io/)\n",
    "!pip install pykalman --quiet\n",
    "\n",
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pykalman\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "#@title Figure settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "#@title Data retrieval and loading\n",
    "import io\n",
    "import os\n",
    "import hashlib\n",
    "import requests\n",
    "\n",
    "fname = \"W2D3_mit_eyetracking_2009.npz\"\n",
    "url = \"https://osf.io/jfk8w/download\"\n",
    "expected_md5 = \"20c7bc4a6f61f49450997e381cf5e0dd\"\n",
    "\n",
    "if not os.path.isfile(fname):\n",
    "  try:\n",
    "    r = requests.get(url)\n",
    "  except requests.ConnectionError:\n",
    "    print(\"!!! Failed to download data !!!\")\n",
    "  else:\n",
    "    if r.status_code != requests.codes.ok:\n",
    "      print(\"!!! Failed to download data !!!\")\n",
    "    elif hashlib.md5(r.content).hexdigest() != expected_md5:\n",
    "      print(\"!!! Data download appears corrupted !!!\")\n",
    "    else:\n",
    "      with open(fname, \"wb\") as fid:\n",
    "        fid.write(r.content)\n",
    "\n",
    "def load_eyetracking_data(data_fname=fname):\n",
    "\n",
    "  with np.load(data_fname, allow_pickle=True) as dobj:\n",
    "    data = dict(**dobj)\n",
    "\n",
    "  images = [plt.imread(io.BytesIO(stim), format='JPG')\n",
    "            for stim in data['stimuli']]\n",
    "  subjects = data['subjects']\n",
    "\n",
    "  return subjects, images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "#@title Helper functions\n",
    "np.set_printoptions(precision=3)\n",
    "\n",
    "\n",
    "def plot_kalman(state, observation, estimate=None, label='filter', color='r-',\n",
    "                title='LDS', axes=None):\n",
    "    if axes is None:\n",
    "      fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 6))\n",
    "      ax1.plot(state[:, 0], state[:, 1], 'g-', label='true latent')\n",
    "      ax1.plot(observation[:, 0], observation[:, 1], 'k.', label='data')\n",
    "    else:\n",
    "      ax1, ax2 = axes\n",
    "\n",
    "    if estimate is not None:\n",
    "      ax1.plot(estimate[:, 0], estimate[:, 1], color=color, label=label)\n",
    "    ax1.set(title=title, xlabel='X position', ylabel='Y position')\n",
    "    ax1.legend()\n",
    "\n",
    "    if estimate is None:\n",
    "      ax2.plot(state[:, 0], observation[:, 0], '.k', label='dim 1')\n",
    "      ax2.plot(state[:, 1], observation[:, 1], '.', color='grey', label='dim 2')\n",
    "      ax2.set(title='correlation', xlabel='latent', ylabel='observed')\n",
    "    else:\n",
    "      ax2.plot(state[:, 0], estimate[:, 0], '.', color=color,\n",
    "               label='latent dim 1')\n",
    "      ax2.plot(state[:, 1], estimate[:, 1], 'x', color=color,\n",
    "               label='latent dim 2')\n",
    "      ax2.set(title='correlation',\n",
    "              xlabel='real latent',\n",
    "              ylabel='estimated latent')\n",
    "    ax2.legend()\n",
    "\n",
    "    return ax1, ax2\n",
    "\n",
    "\n",
    "def plot_gaze_data(data, img=None, ax=None):\n",
    "    # overlay gaze on stimulus\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "    xlim = None\n",
    "    ylim = None\n",
    "    if img is not None:\n",
    "        ax.imshow(img, aspect='auto')\n",
    "        ylim = (img.shape[0], 0)\n",
    "        xlim = (0, img.shape[1])\n",
    "\n",
    "    ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)\n",
    "    ax.set(xlim=xlim, ylim=ylim)\n",
    "\n",
    "    return ax\n",
    "\n",
    "\n",
    "def plot_kf_state(kf, data, ax):\n",
    "    mu_0 = np.ones(kf.n_dim_state)\n",
    "    mu_0[:data.shape[1]] = data[0]\n",
    "    kf.initial_state_mean = mu_0\n",
    "\n",
    "    mu, sigma = kf.smooth(data)\n",
    "    ax.plot(mu[:, 0], mu[:, 1], 'limegreen', linewidth=3, zorder=1)\n",
    "    ax.scatter(mu[0, 0], mu[0, 1], c='orange', marker='>', s=200, zorder=2)\n",
    "    ax.scatter(mu[-1, 0], mu[-1, 1], c='orange', marker='s', s=200, zorder=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 1: Linear Dynamical System (LDS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "outputId": "b12e11b3-17bc-4177-9cff-55193ef3980f"
   },
   "outputs": [],
   "source": [
    "#@title Video 2: Linear Dynamical Systems\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"2SWh639YgEg\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Latent state variable: $$s_t = Fs_{t-1}+\\zeta_t$$\n",
    "\n",
    "Measured/observed variable: $$y_t = Hs_{t}+\\eta_t$$\n",
    "\n",
    "The latent state variable has dimension $D$ and the measured variable has dimension $N$; dimensionality reduction here means that $D<N$.\n",
    "\n",
    "Both the latent and the measured variables have Gaussian noise terms:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\zeta_t & \\sim & N(0, Q) \\\\\n",
    "\\eta_t & \\sim & N(0, R) \\\\\n",
    "s_0 & \\sim & N(\\mu_0, \\Sigma_0)\n",
    "\\end{eqnarray}\n",
    "\n",
    "As a consequence, $s_t$, $y_t$ and their joint distributions are Gaussian so we can easily compute the marginals and conditionals.\n",
    "\n",
    "Just as in the HMM, the structure is that of a Markov chain where the state at time point $t$ is conditionally independent of previous states given the state at time point $t-1$.\n"
   ]
  },
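  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a worked illustration of why the marginals are easy to compute, here is a sketch that follows from the model above. The symbols $\\mu_t$ and $\\Sigma_t$ are introduced here only as shorthand for the marginal mean and covariance of $s_t$. Because each step is linear and the noise terms are zero-mean and independent of the current state, the moments propagate as:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\mu_t & = & F\\mu_{t-1} \\\\\n",
    "\\Sigma_t & = & F\\Sigma_{t-1}F^T+Q \\\\\n",
    "E[y_t] & = & H\\mu_t \\\\\n",
    "Cov[y_t] & = & H\\Sigma_tH^T+R\n",
    "\\end{eqnarray}\n",
    "\n",
    "Starting from $s_0 \\sim N(\\mu_0, \\Sigma_0)$ and iterating these recursions gives the Gaussian marginal of the state and of the observation at any time point."
   ]
  },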
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 1.1: Sampling\n",
    "\n",
    "The first thing we will investigate is how to generate timecourse samples from a linear dynamical system given its parameters. We will start by defining the following system:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# task dimensions\n",
    "n_dim_state = 2\n",
    "n_dim_obs = 2\n",
    "\n",
    "# initialize model parameters\n",
    "params = {\n",
    "  'F': 0.5 * np.eye(n_dim_state),  # state transition matrix\n",
    "  'Q': np.eye(n_dim_state),  # state noise covariance\n",
    "  'H': np.eye(n_dim_state),  # observation matrix\n",
    "  'R': 0.1 * np.eye(n_dim_obs),  # observation noise covariance\n",
    "  'mu_0': np.zeros(n_dim_state),  # initial state mean\n",
    "  'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Note**: We used a parameter dictionary `params` above. As the number of parameters we need to provide to our functions increases, it can be beneficial to condense them into a data structure like this to clean up the number of inputs we pass in. The trade-off is that we have to know what is in our data structure to use those values, rather than looking at the function signature directly."
   ]
  },
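  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To get a feel for these particular values, here is a short worked calculation. It is a sketch under the model definition above, and the symbol $v_t$ for the per-dimension state variance is introduced here only for illustration. Because $F$, $Q$, and $\\Sigma_0$ are all multiples of the identity, each state dimension evolves independently with mean zero and variance\n",
    "\n",
    "$$v_t = 0.5^2 v_{t-1} + 1, \\quad v_0 = 0.1$$\n",
    "\n",
    "so that $v_1 = 1.025$, $v_2 \\approx 1.256$, and the variance settles at the fixed point\n",
    "\n",
    "$$v_\\infty = \\frac{1}{1-0.5^2} = \\frac{4}{3} \\approx 1.33$$\n",
    "\n",
    "The observation noise adds a variance of $0.1$ on top of this, so in the long run each observed dimension has a variance of roughly $1.43$."
   ]
  },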
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Exercise 1: Sampling from a linear dynamical system\n",
    "\n",
    "In this exercise you will implement the dynamics functions of a linear dynamical system to sample both a latent space trajectory (given parameters set above) and noisy measurements.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def sample_lds(n_timesteps, params, seed=0):\n",
    "  \"\"\" Generate samples from a Linear Dynamical System specified by the provided\n",
    "  parameters.\n",
    "\n",
    "  Args:\n",
    "    n_timesteps (int): the number of time steps to simulate\n",
    "    params (dict): a dictionary of model parameters: (F, Q, H, R, mu_0, sigma_0)\n",
    "    seed (int): a random seed to use for reproducibility checks\n",
    "\n",
    "  Returns:\n",
    "    ndarray, ndarray: the generated state and observation data\n",
    "  \"\"\"\n",
    "  n_dim_state = params['F'].shape[0]\n",
    "  n_dim_obs = params['H'].shape[0]\n",
    "\n",
    "  # set seed\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  # precompute random samples from the provided covariance matrices\n",
    "  # mean defaults to 0\n",
    "  zi = stats.multivariate_normal(cov=params['Q']).rvs(n_timesteps)\n",
    "  eta = stats.multivariate_normal(cov=params['R']).rvs(n_timesteps)\n",
    "\n",
    "  # initialize state and observation arrays\n",
    "  state = np.zeros((n_timesteps, n_dim_state))\n",
    "  obs = np.zeros((n_timesteps, n_dim_obs))\n",
    "\n",
    "  ###################################################################\n",
    "  ## TODO for students: compute the next state and observation values\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: compute the next state and observation values\")\n",
    "  ###################################################################\n",
    "\n",
    "  # simulate the system\n",
    "  for t in range(n_timesteps):\n",
    "    # write the expressions for computing state values given the time step\n",
    "    if t == 0:\n",
    "      state[t] = ...\n",
    "    else:\n",
    "      state[t] = ...\n",
    "\n",
    "    # write the expression for computing the observation\n",
    "    obs[t] = ...\n",
    "\n",
    "  return state, obs\n",
    "\n",
    "\n",
    "# Uncomment below to test your function\n",
    "# state, obs = sample_lds(100, params)\n",
    "# print('sample at t=3 ', state[3])\n",
    "# plot_kalman(state, obs, title='sample')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 482
    },
    "colab_type": "text",
    "outputId": "d51c1144-ff76-4d24-c06a-02086c19378b"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W2D3_DecisionMaking/solutions/W2D3_Tutorial3_Solution_8cfee88d.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=1133 height=414 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W2D3_DecisionMaking/static/W2D3_Tutorial3_Solution_8cfee88d_1.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Interactive Demo: Adjusting System Dynamics\n",
    "To test your understanding of the parameters of a linear dynamical system, think about what you would expect if you made the following changes:\n",
    "1. Reduce the observation noise $R$\n",
    "2. Increase the temporal dynamics $F$\n",
    "\n",
    "Use the interactive widget below to vary the values of $R$ and $F$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 494,
     "referenced_widgets": [
      "b2671333e42c4a07be594f4f53c4b6c7",
      "9b9135ad6238416bbfa2aa9ceac9afac",
      "8510497ffc714be5a60f51763ac70d96",
      "f61110098d194109b0d40c0c52b1cf2f",
      "1ad4749667784b2ebd5d8de21a468764",
      "1fa269d799514f298bff7665760e5c8f",
      "03da26b0bb51475aa305b83e9b142068",
      "f9c9bcc9628941c8afcceebb3605ee4e",
      "300ef992f1224dd6ad0e900aacbdca29",
      "bf9934dbb2d443c2a9597c260428ece4"
     ]
    },
    "colab_type": "code",
    "outputId": "4f254285-617a-4bdd-d605-c2559797aaaf"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "@widgets.interact(R=widgets.FloatLogSlider(0.1, min=-3, max=1),\n",
    "                  F=widgets.FloatSlider(0.5, min=0.0, max=1.0))\n",
    "def explore_dynamics(R=0.1, F=0.5):\n",
    "    params = {\n",
    "    'F': F * np.eye(n_dim_state),  # state transition matrix\n",
    "    'Q': np.eye(n_dim_state),  # state noise covariance\n",
    "    'H': np.eye(n_dim_state),  # observation matrix\n",
    "    'R': R * np.eye(n_dim_obs),  # observation noise covariance\n",
    "    'mu_0': np.zeros(n_dim_state),  # initial state mean\n",
    "    'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance\n",
    "    }\n",
    "\n",
    "    state, obs = sample_lds(100, params)\n",
    "    plot_kalman(state, obs, title='sample')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 2: Kalman Filtering\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "outputId": "fc6da9d9-10af-48e7-a00b-23165bee554e"
   },
   "outputs": [],
   "source": [
    "#@title Video 3: Kalman Filtering\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"VboZOV9QMOI\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We want to infer the latent state variable $s_t$ given the measured (observed) variable $y_t$.\n",
    "\n",
    "$$P(s_t|y_1, \\ldots, y_t, y_{t+1}, \\ldots, y_T)\\sim N(\\hat{\\mu}_t, \\hat{\\Sigma}_t)$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "First we obtain estimates of the latent state by running the filter forward over the time steps $t=1, \\ldots, T$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "$$s_t^{pred}\\sim N(\\hat{\\mu}_t^{pred},\\hat{\\Sigma}_t^{pred})$$\n",
    "\n",
    "where $\\hat{\\mu}_t^{pred}$ and $\\hat{\\Sigma}_t^{pred}$ are derived as follows:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\hat{\\mu}_1^{pred} & = & F\\hat{\\mu}_{0} \\\\\n",
    "\\hat{\\mu}_t^{pred} & = & F\\hat{\\mu}_{t-1}\n",
    "\\end{eqnarray}\n",
    "\n",
    "*This is the prediction for $s_t$, obtained simply by taking the expected value of $s_{t-1}$ and projecting it forward one step using the state transition matrix $F$.*\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\hat{\\Sigma}_1^{pred} & = & F\\hat{\\Sigma}_{0}F^T+Q \\\\\n",
    "\\hat{\\Sigma}_t^{pred} & = & F\\hat{\\Sigma}_{t-1}F^T+Q\n",
    "\\end{eqnarray}\n",
    "\n",
    "*The same holds for the covariance, taking into account the state noise covariance $Q$.*\n",
    "\n",
    "Next, we update the prediction with the observation to obtain $\\hat{\\mu}_t^{filter}$ and $\\hat{\\Sigma}_t^{filter}$.\n",
    "\n",
    "Project to observation space:\n",
    "$$y_t^{pred}\\sim N(H\\hat{\\mu}_t^{pred}, H\\hat{\\Sigma}_t^{pred}H^T+R)$$\n",
    "\n",
    "Update the prediction with the actual data:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "s_t^{filter} & \\sim & N(\\hat{\\mu}_t^{filter}, \\hat{\\Sigma}_t^{filter}) \\\\\n",
    "\\hat{\\mu}_t^{filter} & = & \\hat{\\mu}_t^{pred}+K_t(y_t-H\\hat{\\mu}_t^{pred}) \\\\\n",
    "\\hat{\\Sigma}_t^{filter} & = & (I-K_tH)\\hat{\\Sigma}_t^{pred}\n",
    "\\end{eqnarray}\n",
    "\n",
    "Kalman gain matrix:\n",
    "$$K_t=\\hat{\\Sigma}_t^{pred}H^T(H\\hat{\\Sigma}_t^{pred}H^T+R)^{-1}$$\n",
    "\n",
    "*We project the latent-only prediction into observation space and compute a correction proportional to the error $y_t-H\\hat{\\mu}_t^{pred}$ between prediction and data. The coefficient of this correction is the Kalman gain matrix.*\n",
    "\n",
    "*If the measurement noise is small and the dynamics are fast, the estimate will depend mostly on the observed data.*"
   ]
  },
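  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To build intuition for the update step, consider a one-dimensional sketch with $H=1$. The specific numbers below are chosen purely for illustration and do not come from the tutorial data. In the scalar case the Kalman gain reduces to a ratio of variances:\n",
    "\n",
    "$$K_t=\\frac{\\hat{\\Sigma}_t^{pred}}{\\hat{\\Sigma}_t^{pred}+R}$$\n",
    "\n",
    "Suppose $\\hat{\\mu}_t^{pred}=0$, $\\hat{\\Sigma}_t^{pred}=1$, $R=1$, and we observe $y_t=2$. Then\n",
    "\n",
    "\\begin{eqnarray}\n",
    "K_t & = & \\frac{1}{1+1} = 0.5 \\\\\n",
    "\\hat{\\mu}_t^{filter} & = & 0+0.5\\,(2-0) = 1 \\\\\n",
    "\\hat{\\Sigma}_t^{filter} & = & (1-0.5)\\cdot 1 = 0.5\n",
    "\\end{eqnarray}\n",
    "\n",
    "The filtered mean lands halfway between prediction and observation because both are equally uncertain, and the filtered variance is smaller than either. In the limits, $R\\rightarrow 0$ gives $K_t\\rightarrow 1$ (trust the data completely), while $R\\rightarrow\\infty$ gives $K_t\\rightarrow 0$ (ignore the data and keep the prediction)."
   ]
  },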
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In order to explore the impact of filtering, we will use the following noisy periodic system:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 465
    },
    "colab_type": "code",
    "outputId": "b94ffc1d-b996-4b7d-a0ab-d85f3edff3fc"
   },
   "outputs": [],
   "source": [
    "# task dimensions\n",
    "n_dim_state = 2\n",
    "n_dim_obs = 2\n",
    "\n",
    "# initialize model parameters\n",
    "params = {\n",
    "  'F': np.array([[1., 1.], [-(2*np.pi/20.)**2., .9]]),  # state transition matrix\n",
    "  'Q': np.eye(n_dim_state),  # state noise covariance\n",
    "  'H': np.eye(n_dim_state),  # observation matrix\n",
    "  'R': 1.0 * np.eye(n_dim_obs),  # observation noise covariance\n",
    "  'mu_0': np.zeros(n_dim_state),  # initial state mean\n",
    "  'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance\n",
    "}\n",
    "\n",
    "state, obs = sample_lds(100, params)\n",
    "plot_kalman(state, obs, title='sample')"
   ]
  },
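  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "One way to see why this system is periodic is to look at the eigenvalues $\\lambda$ of the transition matrix $F$. This eigenvalue analysis is a sketch added for illustration, with values rounded. For a $2\\times 2$ matrix,\n",
    "\n",
    "$$\\lambda = \\frac{\\mathrm{tr}(F) \\pm \\sqrt{\\mathrm{tr}(F)^2-4\\det(F)}}{2}, \\quad \\mathrm{tr}(F)=1.9, \\quad \\det(F)=0.9+\\left(\\frac{2\\pi}{20}\\right)^2\\approx 0.9987$$\n",
    "\n",
    "The term under the square root is negative, so the eigenvalues form a complex-conjugate pair:\n",
    "\n",
    "$$\\lambda \\approx 0.95 \\pm 0.31i, \\quad |\\lambda| = \\sqrt{\\det(F)} \\approx 0.9993, \\quad \\arg\\lambda \\approx 0.316 \\text{ rad}$$\n",
    "\n",
    "Complex eigenvalues mean that each application of $F$ rotates the state by about $0.316$ radians, which corresponds to a period of roughly $2\\pi/0.316 \\approx 20$ time steps. Because $|\\lambda|$ is just below 1, the oscillation would decay slowly without noise; the state noise $Q$ keeps re-exciting it."
   ]
  },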
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 2: Implement Kalman filtering\n",
    "In this exercise you will implement the Kalman filter (forward) process. Your focus will be on writing the expressions for the Kalman gain, filter mean, and filter covariance at each time step (refer to the equations above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def kalman_filter(data, params):\n",
    "  \"\"\" Perform Kalman filtering (forward pass) on the data given the provided\n",
    "  system parameters.\n",
    "\n",
    "  Args:\n",
    "    data (ndarray): a sequence of observations of shape (n_timesteps, n_dim_obs)\n",
    "    params (dict): a dictionary of model parameters: (F, Q, H, R, mu_0, sigma_0)\n",
    "\n",
    "  Returns:\n",
    "    ndarray, ndarray: the filtered system means and noise covariance values\n",
    "  \"\"\"\n",
    "  # pulled out of the params dict for convenience\n",
    "  F = params['F']\n",
    "  Q = params['Q']\n",
    "  H = params['H']\n",
    "  R = params['R']\n",
    "\n",
    "  n_dim_state = F.shape[0]\n",
    "  n_dim_obs = H.shape[0]\n",
    "  I = np.eye(n_dim_state)  # identity matrix\n",
    "\n",
    "  # state tracking arrays\n",
    "  mu = np.zeros((len(data), n_dim_state))\n",
    "  sigma = np.zeros((len(data), n_dim_state, n_dim_state))\n",
    "\n",
    "  # filter the data\n",
    "  for t, y in enumerate(data):\n",
    "    if t == 0:\n",
    "      mu_pred = params['mu_0']\n",
    "      sigma_pred = params['sigma_0']\n",
    "    else:\n",
    "      mu_pred = F @ mu[t-1]\n",
    "      sigma_pred = F @ sigma[t-1] @ F.T + Q\n",
    "\n",
    "    ###########################################################################\n",
    "    ## TODO for students: compute the filtered state mean and covariance values\n",
    "    # Fill out function and remove\n",
    "    raise NotImplementedError(\"Student exercise: compute the filtered state mean and covariance values\")\n",
    "    ###########################################################################\n",
    "    # write the expression for computing the Kalman gain\n",
    "    K = ...\n",
    "    # write the expression for computing the filtered state mean\n",
    "    mu[t] = ...\n",
    "    # write the expression for computing the filtered state noise covariance\n",
    "    sigma[t] = ...\n",
    "\n",
    "  return mu, sigma\n",
    "\n",
    "\n",
    "# Uncomment below to test your function\n",
    "# filtered_state_means, filtered_state_covariances = kalman_filter(obs, params)\n",
    "# plot_kalman(state, obs, filtered_state_means, title=\"my kf-filter\",\n",
    "#             color='r', label='my kf-filter')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 465
    },
    "colab_type": "text",
    "outputId": "20b04ce7-d931-491a-eca4-cea96f368cb8"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W2D3_DecisionMaking/solutions/W2D3_Tutorial3_Solution_e9df5afe.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=1133 height=414 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W2D3_DecisionMaking/static/W2D3_Tutorial3_Solution_e9df5afe_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 3: Fitting Eye Gaze Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "outputId": "a08ca0cf-a0f1-4294-8da0-3c2af19d4a99"
   },
   "outputs": [],
   "source": [
    "#@title Video 4: Fitting Eye Gaze Data\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"M7OuXmVWHGI\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Tracking eye gaze is used in both experimental and user interface applications. Getting an accurate estimation of where someone is looking on a screen in pixel coordinates can be challenging, however, due to the various sources of noise inherent in obtaining these measurements. A main source of noise is the general accuracy of the eye tracker device itself and how well it maintains calibration over time. Changes in ambient light or subject position can further reduce accuracy of the sensor. Eye blinks introduce a different form of noise as interruptions in the data stream, which also need to be addressed.\n",
    "\n",
    "Fortunately, we have a candidate solution for handling noisy eye gaze data in the Kalman filter we just learned about. Let's look at how we can apply these methods to a small subset of data taken from the [MIT Eyetracking Database](http://people.csail.mit.edu/tjudd/WherePeopleLook/index.html) [[Judd et al. 2009](http://people.csail.mit.edu/tjudd/WherePeopleLook/Docs/wherepeoplelook.pdf)]. This data was collected as part of an effort to model [visual saliency](http://www.scholarpedia.org/article/Visual_salience) -- given an image, can we predict where a person is most likely to look?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# load eyetracking data\n",
    "subjects, images = load_eyetracking_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Interactive Demo: Tracking Eye Gaze\n",
    "\n",
    "We have three stimulus images and five different subjects' gaze data. Each subject fixated in the center of the screen before the image appeared, then had a few seconds to freely look around. You can use the widget below to see how different subjects visually scanned the presented image. A subject ID of -1 will show the stimulus images without any overlaid gaze trace.\n",
    "\n",
    "Note that the images are rescaled below for display purposes; they were in their original aspect ratio during the task itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 494,
     "referenced_widgets": [
      "0e4c9c199636484aa4b2106faa4012b6",
      "bba792b331664ab5926612b0b0fede60",
      "d8bad2fdd3c045c9862e04daed8bb245",
      "e90426f956c842fba2d0e50cec1bcdb2",
      "f3ef359b2564474dbf04c9f02adfbec1",
      "fb8f5cea2b404b2fb63c51f75911f282",
      "77d507c867b1409288c08c6eb737a3c6",
      "21b5ad8c052742f59639fd70d9259e1a",
      "7e4f22dd4235408794e8026dc08c2d4c",
      "503b4fe7160240d1873848f5e1fbee87"
     ]
    },
    "colab_type": "code",
    "outputId": "cebabe02-d9fa-4026-98f1-9872850bbc3a"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "@widgets.interact(subject_id=widgets.IntSlider(-1, min=-1, max=4),\n",
    "                  image_id=widgets.IntSlider(0, min=0, max=2))\n",
    "def plot_subject_trace(subject_id=-1, image_id=0):\n",
    "  if subject_id == -1:\n",
    "    subject = np.zeros((3, 0, 2))\n",
    "  else:\n",
    "    subject = subjects[subject_id]\n",
    "  data = subject[image_id]\n",
    "  img = images[image_id]\n",
    "\n",
    "  fig, ax = plt.subplots()\n",
    "  ax.imshow(img, aspect='auto')\n",
    "  ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)\n",
    "  ax.set(xlim=(0, img.shape[1]), ylim=(img.shape[0], 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 3.1: Fitting data with `pykalman`\n",
    "\n",
    "Now that we have data, we'd like to use Kalman filtering to give us a better estimate of the true gaze. Up until this point we've known the parameters of our LDS, but here we need to estimate them from data directly. We will use the `pykalman` package to handle this estimation using the EM algorithm.\n",
    "\n",
    "Before exploring fitting models with `pykalman` it's worth pointing out some naming conventions used by the library:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "F &: \\texttt{transition_matrices} & \n",
    "Q &: \\texttt{transition_covariance}\\\\\n",
    "H &:\\texttt{observation_matrices} &\n",
    "R &:\\texttt{observation_covariance}\\\\\n",
    "\\mu_0 &: \\texttt{initial_state_mean} & \\Sigma_0 &: \\texttt{initial_state_covariance}\n",
    "\\end{align}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The first thing we need to do is provide a guess at the dimensionality of the latent state. Let's start by assuming the dynamics line up directly with the observation data (pixel x,y-coordinates), and so we have a state dimension of 2.\n",
    "\n",
    "We also need to decide which parameters we want the EM algorithm to fit. In this case, we will let the EM algorithm discover the dynamics parameters i.e. the $F$, $Q$, $H$, and $R$ matrices.\n",
    "\n",
    "We set up our `pykalman` `KalmanFilter` object with these settings using the code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# set up our KalmanFilter object and tell it which parameters we want to\n",
    "# estimate\n",
    "np.random.seed(1)\n",
    "\n",
    "n_dim_obs = 2\n",
    "n_dim_state = 2\n",
    "\n",
    "kf = pykalman.KalmanFilter(\n",
    "  n_dim_state=n_dim_state,\n",
    "  n_dim_obs=n_dim_obs,\n",
    "  em_vars=['transition_matrices', 'transition_covariance',\n",
    "           'observation_matrices', 'observation_covariance']\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Because we know from the reported experimental design that subjects fixated in the center of the screen right before the image appears, we can set the initial starting state estimate $\\mu_0$ as being the center pixel of the stimulus image (the first data point in this sample dataset) with a correspondingly low initial noise covariance $\\Sigma_0$. Once we have everything set, it's time to fit some data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 225
    },
    "colab_type": "code",
    "outputId": "ea9d6b88-6ca3-4704-e58c-51d4efb8b25d"
   },
   "outputs": [],
   "source": [
    "# Choose a subject and stimulus image\n",
    "subject_id = 1\n",
    "image_id = 2\n",
    "data = subjects[subject_id][image_id]\n",
    "\n",
    "# Provide the initial states\n",
    "kf.initial_state_mean = data[0]\n",
    "kf.initial_state_covariance = 0.1*np.eye(n_dim_state)\n",
    "\n",
    "# Estimate the parameters from data using the EM algorithm\n",
    "kf.em(data)\n",
    "\n",
    "print(f'F =\\n{kf.transition_matrices}')\n",
    "print(f'Q =\\n{kf.transition_covariance}')\n",
    "print(f'H =\\n{kf.observation_matrices}')\n",
    "print(f'R =\\n{kf.observation_covariance}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We see that the EM algorithm has found fits for the various dynamics parameters. One thing you will note is that both the state and observation matrices are close to the identity matrix, which means the x- and y-coordinate dynamics are independent of each other and primarily impacted by the noise covariances.\n",
    "\n",
    "We can now use this model to smooth the observed data from the subject. In addition to the source image, we can also see how this model will work with the gaze recorded by the same subject on the other images as well, or even with different subjects.\n",
    "\n",
    "Below are the three stimulus images overlayed with recorded gaze in magenta and smoothed state from the filter in green, with gaze begin (orange triangle) and gaze end (orange square) markers. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 319,
     "referenced_widgets": [
      "660fcc30493b48e2979078a5a90fabd4",
      "fdcf83fdcfe8427fbc670bc45804d3fd",
      "76b462e1e9454b63a414a730d51aaa8b",
      "0e33aa9c1a4c4d73bdb123ef658e491e",
      "a9cdf64c64ef44e7b5b3ebd28efb4c11",
      "c13df8220f9149539275bd7e570970e5",
      "a173cd4dbbca47f3b38576e825ad9a85"
     ]
    },
    "colab_type": "code",
    "outputId": "c5846eaa-77dc-455c-84d0-00672c4d39ba"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "@widgets.interact(subject_id=widgets.IntSlider(1, min=0, max=4))\n",
    "def plot_smoothed_traces(subject_id=0):\n",
    "  subject = subjects[subject_id]\n",
    "  fig, axes = plt.subplots(ncols=3, figsize=(18, 4))\n",
    "  for data, img, ax in zip(subject, images, axes):\n",
    "    ax = plot_gaze_data(data, img=img, ax=ax)\n",
    "    plot_kf_state(kf, data, ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Why do you think one trace from one subject was sufficient to provide a decent fit across all subjects? If you were to go back and change the subject_id and/or image_id for when we fit the data using EM, do you think the fits would be different?\n",
    "\n",
    "Finally, recall that the orignial task was to use this data to help devlop models of visual salience. While our Kalman filter is able to provide smooth estimates of observed gaze data, it's not telling us anything about *why* the gaze is going in a certain direction. In fact, if we sample data from our parameters and plot them, we get what amounts to a random walk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 430
    },
    "colab_type": "code",
    "outputId": "45a98d8f-26bc-47de-8a92-18dbcb7e1e5c"
   },
   "outputs": [],
   "source": [
    "kf_state, kf_data = kf.sample(len(data))\n",
    "ax = plot_gaze_data(kf_data, img=images[2])\n",
    "plot_kf_state(kf, kf_data, ax)"
   ]
  },
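  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "A rough way to see why the samples resemble a random walk is to assume, purely for illustration, that the fitted $F$ and $H$ are exactly the identity (in the fit above they are only close to it). The noise symbols $w_t$ and $v_t$ below are names introduced for this sketch:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "s_t & = & s_{t-1} + w_{t-1}, \\quad w_{t-1} \\sim N(0, Q) \\\\\n",
    "y_t & = & s_t + v_t, \\quad v_t \\sim N(0, R) \\\\\n",
    "s_t & = & s_0 + \\sum_{k=0}^{t-1} w_k\n",
    "\\end{eqnarray}\n",
    "\n",
    "Under this assumption the state mean stays at $\\mu_0$ while its covariance grows as $\\Sigma_0 + tQ$, so the sampled gaze drifts without being drawn toward any feature of the image."
   ]
  },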
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This should not be surprising, as we have given the model no other observed data beyond the pixels at which gaze was detected. We expect there is some other aspect driving the latent state of where to look next other than just the previous fixation location.\n",
    "\n",
    "In summary, while the Kalman filter is a good option for smoothing the gaze trajectory itself, especially if using a lower-quality eye tracker or in noisy environmental conditions, a linear dynamical system may not be the right way to approach the much more challenging task of modeling visual saliency.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Review on Gaussian joint, marginal and conditional distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Assume\n",
    "\n",
    "\\begin{eqnarray}\n",
    "z & = & [x^Ty^T]^T \\\\\n",
    "z & = & \\begin{bmatrix}x \\\\y\\end{bmatrix}\\sim N\\left(\\begin{bmatrix}a \\\\b\\end{bmatrix}, \\begin{bmatrix}A & C \\\\C^T & B\\end{bmatrix}\\right)\n",
    "\\end{eqnarray}\n",
    "\n",
    "then the marginal distributions are\n",
    "\n",
    "\\begin{eqnarray}\n",
    "x & \\sim & N(a, A) \\\\\n",
    "y & \\sim & N(b,B)\n",
    "\\end{eqnarray}\n",
    "\n",
    "and the conditional distributions are\n",
    "\n",
    "\\begin{eqnarray}\n",
    "x|y & \\sim & N(a+CB^{-1}(y-b), A-CB^{-1}C^T) \\\\\n",
    "y|x & \\sim & N(b+C^TA^{-1}(x-a), B-C^TA^{-1}C)\n",
    "\\end{eqnarray}\n",
    "\n",
    "*important take away: given the joint Gaussian distribution we can derive the conditionals*"
   ]
  },
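  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a quick sanity check of the conditional formulas, consider a scalar case with values chosen here purely for illustration: $a=b=0$, $A=B=1$, and $C=\\rho$. Substituting into the expression for $x|y$ gives\n",
    "\n",
    "\\begin{eqnarray}\n",
    "x|y & \\sim & N(\\rho y, 1-\\rho^2)\n",
    "\\end{eqnarray}\n",
    "\n",
    "For example, with $\\rho=0.8$ and an observed $y=1$, the conditional mean of $x$ is $0.8$ and the conditional variance is $1-0.64=0.36$. Observing $y$ both shifts the estimate of $x$ and reduces its uncertainty, and the stronger the correlation, the larger both effects."
   ]
  },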
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Kalman Smoothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "outputId": "9b06ef4d-2b47-4281-8335-e49bc386a11f"
   },
   "outputs": [],
   "source": [
    "#@title Video 5: Kalman Smoothing and the EM Algorithm\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"4Ar2mYz1Nms\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Obtain estimates by propagating from $y_T$ back to $y_0$ using results of forward pass ($\\hat{\\mu}_t^{filter}, \\hat{\\Sigma}_t^{filter}, P_t=\\hat{\\Sigma}_{t+1}^{pred}$)\n",
    "\n",
    "\\begin{eqnarray}\n",
    "s_t & \\sim & N(\\hat{\\mu}_t^{smooth}, \\hat{\\Sigma}_t^{smooth}) \\\\\n",
    "\\hat{\\mu}_t^{smooth} & = & \\hat{\\mu}_t^{filter}+J_t(\\hat{\\mu}_{t+1}^{smooth}-F\\hat{\\mu}_t^{filter}) \\\\\n",
    "\\hat{\\Sigma}_t^{smooth} & = & \\hat{\\Sigma}_t^{filter}+J_t(\\hat{\\Sigma}_{t+1}^{smooth}-P_t)J_t^T \\\\\n",
    "J_t & = & \\hat{\\Sigma}_t^{filter}F^T P_t^{-1}\n",
    "\\end{eqnarray}\n",
    "\n",
    "This gives us the final estimate for $z_t$.\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\hat{\\mu}_t & = & \\hat{\\mu}_t^{smooth} \\\\\n",
    "\\hat{\\Sigma}_t & = & \\hat{\\Sigma}_t^{smooth}\n",
    "\\end{eqnarray}"
   ]
  },
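  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "One way to see where $J_t$ comes from is to apply the Gaussian conditional formula from the review section to the joint distribution of $s_t$ and $s_{t+1}$ given the observations up to time $t$. This is a sketch of the standard argument rather than a full derivation:\n",
    "\n",
    "\\begin{eqnarray}\n",
    "\\begin{bmatrix}s_t \\\\ s_{t+1}\\end{bmatrix} & \\sim & N\\left(\\begin{bmatrix}\\hat{\\mu}_t^{filter} \\\\ F\\hat{\\mu}_t^{filter}\\end{bmatrix}, \\begin{bmatrix}\\hat{\\Sigma}_t^{filter} & \\hat{\\Sigma}_t^{filter}F^T \\\\ F\\hat{\\Sigma}_t^{filter} & P_t\\end{bmatrix}\\right) \\\\\n",
    "s_t|s_{t+1} & \\sim & N\\left(\\hat{\\mu}_t^{filter}+J_t(s_{t+1}-F\\hat{\\mu}_t^{filter}), \\hat{\\Sigma}_t^{filter}-J_tP_tJ_t^T\\right)\n",
    "\\end{eqnarray}\n",
    "\n",
    "Here $P_t=F\\hat{\\Sigma}_t^{filter}F^T+Q$, and $J_t=\\hat{\\Sigma}_t^{filter}F^TP_t^{-1}$ plays the role of $CB^{-1}$ in the conditional formula. Averaging over $s_{t+1}\\sim N(\\hat{\\mu}_{t+1}^{smooth}, \\hat{\\Sigma}_{t+1}^{smooth})$ replaces $s_{t+1}$ with $\\hat{\\mu}_{t+1}^{smooth}$ in the mean and adds $J_t\\hat{\\Sigma}_{t+1}^{smooth}J_t^T$ to the covariance, which recovers the smoothing updates."
   ]
  },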
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Exercise 3: Implement Kalman smoothing\n",
    "\n",
    "In this exercise you will implement the Kalman smoothing (backward) process. Again you will focus on writing the expressions for computing the smoothed mean, smoothed covariance, and $J_t$ values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def kalman_smooth(data, params):\n",
    "  \"\"\" Perform Kalman smoothing (backward pass) on the data given the provided\n",
    "  system parameters.\n",
    "\n",
    "  Args:\n",
    "    data (ndarray): a sequence of osbervations of shape(n_timesteps, n_dim_obs)\n",
    "    params (dict): a dictionary of model paramters: (F, Q, H, R, mu_0, sigma_0)\n",
    "\n",
    "  Returns:\n",
    "    ndarray, ndarray: the smoothed system means and noise covariance values\n",
    "  \"\"\"\n",
    "  # pulled out of the params dict for convenience\n",
    "  F = params['F']\n",
    "  Q = params['Q']\n",
    "  H = params['H']\n",
    "  R = params['R']\n",
    "\n",
    "  n_dim_state = F.shape[0]\n",
    "  n_dim_obs = H.shape[0]\n",
    "\n",
    "  # first run the forward pass to get the filtered means and covariances\n",
    "  mu, sigma = kalman_filter(data, params)\n",
    "\n",
    "  # initialize state mean and covariance estimates\n",
    "  mu_hat = np.zeros_like(mu)\n",
    "  sigma_hat = np.zeros_like(sigma)\n",
    "  mu_hat[-1] = mu[-1]\n",
    "  sigma_hat[-1] = sigma[-1]\n",
    "\n",
    "  # smooth the data\n",
    "  for t in reversed(range(len(data)-1)):\n",
    "    sigma_pred = F @ sigma[t] @ F.T + Q  # sigma_pred at t+1\n",
    "    ###########################################################################\n",
    "    ## TODO for students: compute the smoothed state mean and covariance values\n",
    "    # Fill out function and remove\n",
    "    raise NotImplementedError(\"Student excercise: compute the smoothed state mean and covariance values\")\n",
    "    ###########################################################################\n",
    "\n",
    "    # write the expression to compute the Kalman gain for the backward process\n",
    "    J = ...\n",
    "    # write the expression to compute the smoothed state mean estimate\n",
    "    mu_hat[t] = ...\n",
    "    # write the expression to compute the smoothed state noise covariance estimate\n",
    "    sigma_hat[t] = ...\n",
    "\n",
    "  return mu_hat, sigma_hat\n",
    "\n",
    "\n",
    "# Uncomment once the kalman_smooth function is complete\n",
    "# smoothed_state_means, smoothed_state_covariances = kalman_smooth(obs, params)\n",
    "# axes = plot_kalman(state, obs, filtered_state_means, color=\"r\",\n",
    "#                    label=\"my kf-filter\")\n",
    "# plot_kalman(state, obs, smoothed_state_means, color=\"b\",\n",
    "#             label=\"my kf-smoothed\", axes=axes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 465
    },
    "colab_type": "text",
    "outputId": "d07e06a3-7306-4491-fefb-b293c092b1fc"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W2D3_DecisionMaking/solutions/W2D3_Tutorial3_Solution_a0f4822b.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=1133 height=414 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W2D3_DecisionMaking/static/W2D3_Tutorial3_Solution_a0f4822b_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Forward vs Backward**\n",
    "\n",
    "Now that we have implementations for both, let's compare their peformance by computing the MSE between the filtered (forward) and smoothed (backward) estimated states and the true latent state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "outputId": "0c29490d-a511-433f-c542-bb3459f90941"
   },
   "outputs": [],
   "source": [
    "print(f\"Filtered MSE: {np.mean((state - filtered_state_means)**2):.3f}\")\n",
    "print(f\"Smoothed MSE: {np.mean((state - smoothed_state_means)**2):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In this example, the smoothed estimate is clearly superior to the filtered one. This makes sense as the backward pass is able to use the forward pass estimates and correct them given all the data we've collected.\n",
    "\n",
    "So why would you ever use Kalman filtering alone, without smoothing? As Kalman filtering only depends on already observed data (i.e. the past) it can be run in a streaming, or on-line, setting. Kalman smoothing relies on future data as it were, and as such can only be applied in a batch, or off-line, setting. So use Kalman filtering if you need real-time corrections and Kalman smoothing if you are considering already-collected data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## The Expectation-Maximization (EM) Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "- want to maximize $log p(y|\\theta)$\n",
    "\n",
    "- need to marginalize out latent state *(which is not tractable)*\n",
    "\n",
    "$$p(y|\\theta)=\\int p(y,s|\\theta)dz$$\n",
    "\n",
    "- add a probability distribution $q(s)$ which will approximate the latent state distribution \n",
    "\n",
    "$$log p(y|\\theta)\\int_s q(s)dz$$\n",
    "\n",
    "- can be rewritten as\n",
    "\n",
    "$$\\mathcal{L}(q,\\theta)+KL\\left(q(s)||p(s|y),\\theta\\right)$$\n",
    "\n",
    "- $\\mathcal{L}(q,\\theta)$ contains the joint distribution of $y$ and $s$\n",
    "\n",
    "- $KL(q||p)$ contains the conditional distribution of $s|y$\n",
    "\n",
    "#### Expectation step\n",
    "- parameters are kept fixed\n",
    "- find a good approximation $q(s)$: maximize lower bound $\\mathcal{L}(q,\\theta)$ with respect to $q(s)$\n",
    "- (already implemented Kalman filter+smoother)\n",
    "\n",
    "#### Maximization step\n",
    "- keep distribution $q(s)$ fixed\n",
    "- change parameters to maximize the lower bound $\\mathcal{L}(q,\\theta)$\n",
    "\n",
    "As mentioned, we have already effectively solved for the E-Step with our Kalman filter and smoother. The M-step requires further derivation, which is covered in the Appendix. Rather than having you implement the M-Step yourselves, let's instead turn to using a library that has already implemented EM for exploring some experimental data from cognitive neuroscience.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### The M-step for a LDS\n",
    "*(see Bishop, chapter 13.3.2 Learning in LDS)*\n",
    "Update parameters of the probability distribution\n",
    "\n",
    "*For the updates in the M-step we will need the following posterior marginals obtained from the Kalman smoothing results* $\\hat{\\mu}_t^{smooth}, \\hat{\\Sigma}_t^{smooth}$\n",
    "\n",
    "$$\n",
    "\\begin{eqnarray}\n",
    "E(s_t) &=& \\hat{\\mu}_t \\\\\n",
    "E(s_ts_{t-1}^T) &=& J_{t-1}\\hat{\\Sigma}_t+\\hat{\\mu}_t\\hat{\\mu}_{t-1}^T\\\\\n",
    "E(s_ts_{t}^T) &=& \\hat{\\Sigma}_t+\\hat{\\mu}_t\\hat{\\mu}_{t}^T\n",
    "\\end{eqnarray}\n",
    "$$\n",
    "\n",
    "**Update parameters**\n",
    "\n",
    "Initial parameters\n",
    "$$\n",
    "\\begin{eqnarray}\n",
    "\\mu_0^{new}&=& E(s_0)\\\\\n",
    "Q_0^{new} &=& E(s_0s_0^T)-E(s_0)E(s_0^T) \\\\\n",
    "\\end{eqnarray}\n",
    "$$\n",
    "\n",
    "Hidden (latent) state parameters\n",
    "$$\n",
    "\\begin{eqnarray}\n",
    "F^{new} &=& \\left(\\sum_{t=2}^N E(s_ts_{t-1}^T)\\right)\\left(\\sum_{t=2}^N E(s_{t-1}s_{t-1}^T)\\right)^{-1} \\\\\n",
    "Q^{new} &=& \\frac{1}{T-1} \\sum_{t=2}^N E\\big(s_ts_t^T\\big) - F^{new}E\\big(s_{t-1}s_{t}^T\\big) - E\\big(s_ts_{t-1}^T\\big)F^{new}+F^{new}E\\big(s_{t-1}s_{t-1}^T\\big)\\big(F^{new}\\big)^{T}\\\\\n",
    "\\end{eqnarray}\n",
    "$$\n",
    "\n",
    "Observable (measured) space parameters\n",
    "$$H^{new}=\\left(\\sum_{t=1}^N y_t E(s_t^T)\\right)\\left(\\sum_{t=1}^N E(s_t s_t^T)\\right)^{-1}$$\n",
    "$$R^{new}=\\frac{1}{T}\\sum_{t=1}^Ny_ty_t^T-H^{new}E(s_t)y_t^T-y_tE(s_t^T)H^{new}+H^{new}E(s_ts_t^T)H_{new}$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Handling Eye Blinks\n",
    "\n",
    "In the MIT Eyetracking Database, raw tracking data includes times when the subject blinked. The way this is represented in the data stream is via negative pixel coordinate values.\n",
    "\n",
    "We could try to mitigate these samples by simply deleting them from the stream, though this introduces other issues. For instance, if each sample corresponds to a fixed time step, and you arbitrarily remove some samples, the integrity of that consistent timestep between samples is lost. It's sometimes better to flag data as missing rather than to pretend it was never there at all, especially with time series data.\n",
    "\n",
    "Another solution is to used masked arrays. In `numpy`, a [masked array](https://numpy.org/doc/stable/reference/maskedarray.generic.html#what-is-a-masked-array) is an `ndarray` with an additional embedded boolean masking array that indicates which elements should be masked. When computation is performed on the array, the masked elements are ignored. Both `matplotlib` and `pykalman` work with masked arrays, and, in fact, this is the approach taken with the data we explore in this notebook. \n",
    "\n",
    "In preparing the dataset for this noteook, the original dataset was preprocessed to set all gaze data as masked arrays, with the mask enabled for any pixel with a negative x or y coordinate."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W2D3_Tutorial3",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
